{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/mymusise/ChatGLM-Tuning.git\n",
    "%cd  ChatGLM-Tuning\n",
    "!pip install -r requirements.txt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "Loading checkpoint shards: 100%|██████████| 8/8 [00:05<00:00,  1.46it/s]\n"
     ]
    }
   ],
   "source": [
    "from modeling_chatglm import ChatGLMForConditionalGeneration\n",
    "import torch\n",
    "\n",
    "\n",
    "torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
    "model = ChatGLMForConditionalGeneration.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True, device_map='auto')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /usr/local/cuda-11.5/lib64/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.6\n",
      "CUDA SETUP: Detected CUDA version 115\n",
      "CUDA SETUP: Loading binary /home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda115.so...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading (…)/adapter_config.json: 100%|██████████| 381/381 [00:00<00:00, 73.8kB/s]\n",
      "Downloading (…)\"adapter_model.bin\";: 100%|██████████| 7.36M/7.36M [00:01<00:00, 3.95MB/s]\n"
     ]
    }
   ],
   "source": [
    "from peft import PeftModel\n",
    "\n",
    "model = PeftModel.from_pretrained(model, \"mymusise/chatGLM-6B-alpaca-lora\")\n",
    "torch.set_default_tensor_type(torch.cuda.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instruction: Give three tips for staying healthy.\n",
      "Answer: 1. Eat a balanced diet of fruits, vegetables, lean protein, and whole grains.\n",
      "2. Get regular exercise, such as walking, running, or swimming.\n",
      "3. Stay hydrated by drinking plenty of water.\n",
      "### 1.Answer:\n",
      " 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n",
      "2. Exercise regularly to keep your body active and strong. \n",
      "3. Get enough sleep and maintain a consistent sleep schedule. \n",
      "\n",
      "\n",
      "Instruction: What are the three primary colors?\n",
      "Answer: The three primary colors are red, blue, and yellow.\n",
      "### 2.Answer:\n",
      " The three primary colors are red, blue, and yellow. \n",
      "\n",
      "\n",
      "Instruction: 请给出三个健康生活的建议\n",
      "Answer: 1. 坚持运动,每周至少进行150分钟的中等强度有氧运动。 \n",
      "2. 饮食健康,多吃蔬菜、水果、全谷物、蛋白质和健康脂肪。 \n",
      "3. 睡眠充足,每天保证7-9小时的睡眠时间。\n",
      "### 3.Answer:\n",
      " 1.保持均衡的饮食：饮食是影响身体健康的一个重要因素。\n",
      "2.定期运动：运动对身体健康有很多好处。\n",
      "3.睡眠充足：良好的睡眠质量对身体健康至关重要。 \n",
      "\n",
      "\n",
      "Instruction: 三原色是什么?(请用中文)\n",
      "Answer: 三原色是指红、黄、蓝三种颜色。\n",
      "### 4.Answer:\n",
      " 三原色是指色彩三基色，即红色、绿色和蓝色。在色彩学中，通过不同比例的三原色的叠加可以形成各种色彩，这也是彩色显示技术和印刷技术的基础 \n",
      "\n",
      "\n",
      "Instruction: 以下哪家公司类型与其它的不同\n",
      "Input: 微博、京东、淘宝\n",
      "Answer: 微博是社交媒体平台,而淘宝和京东是电商平台。\n",
      "### 5.Answer:\n",
      " 微博 \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from cover_alpaca2jsonl import format_example\n",
    "\n",
    "# alpaca数据集\n",
    "instructions = [\n",
    " {'instruction': 'Give three tips for staying healthy.',\n",
    "  'input': '',\n",
    "  'output': '1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.',\n",
    " },\n",
    " {'instruction': 'What are the three primary colors?',\n",
    "  'input': '',\n",
    "  'output': 'The three primary colors are red, blue, and yellow.',\n",
    " }]\n",
    "\n",
    "# alpaca数据集，问题翻译为中文，output为chatgpt的输出\n",
    "instructions += [\n",
    " {'instruction': '请给出三个健康生活的建议',\n",
    "  'input': '',\n",
    "  'output': '1.保持均衡的饮食：饮食是影响身体健康的一个重要因素。\\n2.定期运动：运动对身体健康有很多好处。\\n3.睡眠充足：良好的睡眠质量对身体健康至关重要。',\n",
    " },\n",
    " {'instruction': '三原色是什么？(请用中文)',\n",
    "  'input': '',\n",
    "  'output': '三原色是指色彩三基色，即红色、绿色和蓝色。在色彩学中，通过不同比例的三原色的叠加可以形成各种色彩，这也是彩色显示技术和印刷技术的基础',\n",
    " },\n",
    " {'instruction': '以下哪家公司类型与其它的不同',\n",
    "  'input': '微博、京东、淘宝',\n",
    "  'output': '微博',\n",
    " }\n",
    "]\n",
    "with torch.no_grad():\n",
    "    for idx, item in enumerate(instructions):\n",
    "        feature = format_example(item)\n",
    "        input_text = feature['context']\n",
    "        ids = tokenizer.encode(input_text)\n",
    "        input_ids = torch.LongTensor([ids])\n",
    "        out = model.generate(\n",
    "            input_ids=input_ids,\n",
    "            max_length=150,\n",
    "            do_sample=False,\n",
    "            temperature=0\n",
    "        )\n",
    "        out_text = tokenizer.decode(out[0])\n",
    "        answer = out_text.replace(input_text, \"\").replace(\"\\nEND\", \"\").strip()\n",
    "        item['infer_answer'] = answer\n",
    "        print(out_text)\n",
    "        print(f\"### {idx+1}.Answer:\\n\", item.get('output'), '\\n\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "25273a2a68c96ebac13d7fb9e0db516f9be0772777a0507fe06d682a441a3ba7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
